import gdal
import os

import numpy
import numpy as np
import glob
import csv
import utility_functions as utility_functions
import fiona
import rasterio

from rasterio.mask import mask
import argparse

def clip_raster(input, umask, output):
    """Crop a raster to the geometries of a vector file and save it as GeoTIFF.

    Parameters
    ----------
    input : str
        Path of the raster to clip; opened with ``rasterio.open``.
    umask : str
        Path of the vector dataset (e.g. a dissolved city-boundary
        shapefile) whose feature geometries define the clip region.
    output : str
        Destination path; written with the GTiff driver.

    The written file inherits the source raster's metadata. Three things
    are overridden:

    * the driver is forced to GTiff;
    * height, width and transform are replaced by the cropped extent;
    * the nodata value is declared as -9999.

    The shape of the cropped array is printed as a progress message.
    """
    with fiona.open(umask, "r") as vector_src:
        geometries = [feat["geometry"] for feat in vector_src]

    with rasterio.open(input) as raster_src:
        # crop=True shrinks the extent to the geometries' bounding box; the
        # returned array is indexed (band, row, col).
        clipped, clipped_transform = mask(raster_src, geometries, crop=True)
        profile = raster_src.meta
        # Hide cells already holding -9999 behind a NumPy mask.
        # NOTE(review): whether this mask changes what dest.write() stores
        # depends on rasterio's masked-array handling -- confirm it is needed.
        clipped = np.ma.masked_where(clipped == -9999, clipped)

    n_rows, n_cols = clipped.shape[1:3]
    profile.update(
        driver="GTiff",
        height=n_rows,
        width=n_cols,
        transform=clipped_transform,
        nodata=-9999,
    )
    print(clipped.shape)

    with rasterio.open(output, "w", **profile) as dest:
        dest.write(clipped)

def savetiff(workdir, raster, array, trans, prj, outputid):
    """Save a 2-D array as a UInt16 GeoTIFF named after its source raster.

    The file is written to ``<workdir>/<outputid>/<stem>_<outputid>.tif``,
    where ``<stem>`` is *raster*'s basename without its extension. The
    ``<workdir>/<outputid>`` folder is created if it does not exist.

    Parameters
    ----------
    workdir : str
        Parent folder of the per-product output folder.
    raster : str
        Path of the source raster; only its file name is used.
    array : numpy.ndarray
        2-D array to save. Cells equal to -9999 are written as 0. The
        caller's array is NOT modified.
    trans, prj
        Geotransform and projection passed through unchanged to
        ``utility_functions.saveGeoTiff`` (in this file they come from
        ``utility_functions.get_geo`` -- confirm for other callers).
    outputid : str
        Name of the output sub-folder and suffix of the output file name.

    Returns
    -------
    str
        Path of the written GeoTIFF.
    """
    out_dir = os.path.join(workdir, outputid)
    # exist_ok avoids the race between a separate exists() check and mkdir().
    os.makedirs(out_dir, exist_ok=True)
    stem = os.path.splitext(os.path.basename(raster))[0]
    save_path = os.path.join(out_dir, stem + '_' + outputid + '.tif')

    y_res, x_res = array.shape

    # Build a new array instead of aliasing `array`. The previous version
    # overwrote the caller's -9999 cells with 0 in place.
    # NOTE(review): negative or NaN values (e.g. negative intensities) have no
    # defined UInt16 representation and are not clipped here -- confirm the
    # intended encoding before relying on such pixels.
    out_array = np.where(array == -9999, 0, array).astype(np.uint16)
    utility_functions.saveGeoTiff(save_path, out_array, [x_res, y_res, 1], trans, prj, dtype=gdal.GDT_UInt16)

    return save_path

def _first_glob_match(pattern, description):
    """Return the first path ``glob.glob(pattern)`` yields, or raise a clear error.

    Replaces bare ``glob.glob(...)[0]``, which fails with an uninformative
    IndexError when nothing matches. As with ``[0]``, the choice among several
    matches follows glob's (filesystem-dependent) order.
    """
    matches = glob.glob(pattern)
    if not matches:
        raise FileNotFoundError("No %s found matching %r" % (description, pattern))
    return matches[0]


def _read_raster_array(path):
    """Open *path* with GDAL and return its ``ReadAsArray()`` as a NumPy array."""
    dataset = gdal.Open(path)
    if dataset is None:
        # gdal.Open returns None on failure when GDAL exceptions are disabled.
        raise IOError("GDAL could not open raster: %s" % path)
    # The dataset reference is dropped (letting GDAL release it) on return.
    return np.array(dataset.ReadAsArray())


# Annual UHI intensity for one LST raster
def hotspot(workdir, raster_path, iata,uhi_lc_root,lc_clipped_folder,ard_city_dissolved_shapefile):
    """Compute an annual UHI intensity raster for one LST raster and save it.

    Urban pixels are those with clipped UHI_LC class > 9. For each urban
    pixel, intensity is its value in *raster_path* minus column 9 of the
    matching row of the annual ``*_ZMEAN.csv`` class-stats table. The
    original notes describe that column as ``nulst_mean``, the zone's
    non-urban LST mean. Non-urban pixels are set to -9999. The result is
    written by ``savetiff`` under ``<workdir>/INTENSITY_ANNUAL/``.

    Parameters
    ----------
    workdir : str
        Phase1 city output folder (e.g. .../Phase1/FSD/output_folder). It
        must contain ``UrbanZones/*<iata>.tif`` and
        ``LST_CLASS_STATS_ANNUAL/ByCSV/*<year>*_ZMEAN.csv``.
    raster_path : str
        Input LST raster. ``basename[11:15]`` is its year, or ``"STAC"`` for
        STACK inputs, which then use the most recent UHI_LC year.
    iata : str
        City identifier, used to find the urban-zones raster.
    uhi_lc_root : str
        Prefix (normally a folder path ending in a separator) of the UHI_LC
        rasters. The year sits at ``basename[11:15]`` of each raster's name.
    lc_clipped_folder : str
        Folder that receives the UHI_LC raster clipped to the city boundary.
    ard_city_dissolved_shapefile : str
        Vector file used to clip the UHI_LC raster.

    Returns
    -------
    str
        Path of the written intensity GeoTIFF (previously None was returned).

    Raises
    ------
    FileNotFoundError
        If the urban-zones raster, the stats CSV or the UHI_LC raster is missing.
    """
    ### LOAD INPUT RASTERS ###
    uz_file = _first_glob_match(
        workdir + os.sep + "UrbanZones" + os.sep + "*" + iata + ".tif",
        "urban-zones raster")
    uz_array = _read_raster_array(uz_file)

    fn_array = _read_raster_array(raster_path)
    trans, prj = utility_functions.get_geo(raster_path)

    # Year of the input; STACK inputs fall back to the newest UHI_LC year.
    year = os.path.basename(raster_path)[11:15]
    if year == "STAC":
        lc_years = [int(os.path.basename(lc)[11:15])
                    for lc in glob.glob(uhi_lc_root + "*.tif")]
        year = str(max(lc_years + [0]))
    print(year)

    # Annual class statistics are produced by phase2_uhi_stats_classbased.py.
    stats_pattern = (workdir + os.sep + "LST_CLASS_STATS_ANNUAL" + os.sep + "ByCSV"
                     + os.sep + "*" + year + "*" + "_ZMEAN.csv")
    print(stats_pattern)
    stats_csv = _first_glob_match(stats_pattern, "annual class-stats CSV")
    print(stats_csv)
    with open(stats_csv) as csv_file:
        data = list(csv.reader(csv_file))
    # Row 0 is the header. Blank lines are skipped instead of crashing on row[0].
    stats_rows = [row for row in data[1:] if row]

    # UHI_LC rasters are the output of step 332
    # (e.g. .../Phase2/section_332/output/FSD).
    lc_file = _first_glob_match(uhi_lc_root + "*" + year + "*.tif",
                                "UHI_LC raster for year " + year)
    print(lc_file)
    lc_clipped_file = (lc_clipped_folder + os.sep
                       + os.path.splitext(os.path.basename(lc_file))[0] + '_clipped.tif')
    print(lc_clipped_file)
    clip_raster(lc_file, ard_city_dissolved_shapefile, lc_clipped_file)
    lc_array = _read_raster_array(lc_clipped_file)

    # Per the original notes, the clipped LC raster comes out one row and one
    # column larger than the LST raster; trim it to the LST extent.
    lc_array = lc_array[0:fn_array.shape[0], 0:fn_array.shape[1]]

    no_data = -9999
    is_urban = lc_array > 9
    # Urban pixels start as the raw input value; every other pixel is nodata.
    # NOTE(review): urban pixels whose zone has no CSV row keep this raw value
    # rather than a difference -- confirm that is intended.
    diff_array = np.where(is_urban, fn_array, no_data)

    print("Processing Hotspot, Urban vs Non-Urban, and Normalized by LC analyses")

    for row in stats_rows:
        uhi_id = row[0]
        lc_class = row[1]
        print(" Processing for LC class: " + str(lc_class))
        # The zone code is the last four characters of the UHI ID.
        # NOTE(review): the mask ignores lc_class, so the last CSV row of each
        # UHI ID decides the value for all of that zone's urban pixels.
        in_zone = (uz_array == int(uhi_id[-4:])) & is_urban
        diff_array = np.where(in_zone, fn_array - float(row[9]), diff_array)

    # Stop-gap from the original: meant to exclude nodata input pixels, which
    # land far below -500 after subtraction. Real values below -500 are lost too.
    diff_array[diff_array < -500] = no_data

    return savetiff(workdir, raster_path, diff_array, trans, prj, 'INTENSITY_ANNUAL')


if __name__ == '__main__':
    # Command-line entry point. Every option defaults to the values that were
    # previously hard-coded for the FSD run on the caldera HPC filesystem, so
    # running with no arguments behaves as before.
    # LST_ANNUAL_CLIPPED (under the work directory) is produced by
    # phase2_1_uhi_urbanbuffer_clip.py, per the original notes.
    parser = argparse.ArgumentParser(
        description="Compute annual UHI intensity rasters for one city.")
    parser.add_argument("--workdir", "-w", default=None,
                        help="Phase1 city output folder (default: "
                             "/caldera/projects/usgs/eros/urban_heat_islands/Phase1/<IATA>/output_folder)")
    parser.add_argument("--iata_location_identifier", "-i", default="FSD",
                        help="IATA location identifier of the city (default: FSD)")
    parser.add_argument("--uhi_lc_root", "-m", default=None,
                        help="folder holding the UHI_LC rasters (default: "
                             "/caldera/projects/usgs/eros/urban_heat_islands/Phase2/section_332/output/<IATA>/)")
    parser.add_argument("--ard_city_dissolved_shapefile", "-s", default=None,
                        help="dissolved city-boundary shapefile used to clip the LC raster")
    args = parser.parse_args()

    iata = args.iata_location_identifier
    workdir = args.workdir or (
        "/caldera/projects/usgs/eros/urban_heat_islands/Phase1/" + iata + "/output_folder")
    uhi_lc_root = args.uhi_lc_root or (
        "/caldera/projects/usgs/eros/urban_heat_islands/Phase2/section_332/output" + os.sep + iata + os.sep)
    # hotspot() builds glob patterns as uhi_lc_root + "*...", so the root must
    # end with a path separator.
    uhi_lc_root = os.path.join(uhi_lc_root, "")
    ard_city_dissolved_shapefile = args.ard_city_dissolved_shapefile or (
        "/caldera/projects/usgs/eros/urban_heat_islands/Phase2/uhi_metrobuffer_generation/City_Boundaries_Buffer/ard_"
        + iata + "_city_dissolved.shp")

    match_string = workdir + os.sep + 'LST_ANNUAL_CLIPPED' + os.sep + 'UHI_CU_' + iata + '_*.tif'
    raster_list = glob.glob(match_string)
    if not raster_list:
        print("No input rasters found matching " + match_string)

    lc_clipped_output_folder = workdir + os.sep + 'LC_clipped'
    os.makedirs(lc_clipped_output_folder, exist_ok=True)

    for raster in raster_list:
        hotspot(workdir, raster, iata, uhi_lc_root, lc_clipped_output_folder, ard_city_dissolved_shapefile)